{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Auto Merging Retrieval"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we showcase our AutoMergingRetriever, which looks at a set of leaf nodes and recursively \"merges\" subsets of leaf nodes that reference a parent node beyond a given threshold. This allows us to consolidate potentially disparate, smaller contexts into a larger context that might help synthesis.\n",
    "\n",
    "You can define this hierarchy yourself over a set of documents, or you can make use of our brand-new text parser: a HierarchicalNodeParser that takes in a candidate set of documents and outputs an entire hierarchy of nodes, from \"coarse-to-fine\"."
   ]
  },
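  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of the mechanics: the retriever walks retrieved leaf nodes up to their parents and, once the retrieved fraction of a parent's children crosses a merge threshold, swaps the children for the parent. A minimal sketch is below (the `simple_ratio_thresh` parameter name and its `0.5` default come from the LlamaIndex API; verify them against your installed version):\n",
    "\n",
    "```python\n",
    "retriever = AutoMergingRetriever(\n",
    "    base_retriever,\n",
    "    storage_context,\n",
    "    simple_ratio_thresh=0.5,  # merge once >50% of a parent's children are retrieved\n",
    "    verbose=True,\n",
    ")\n",
    "```"
   ]
  },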
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install llama-index\n",
    "# !pip install llama-index-vector-stores-qdrant llama-index-readers-file llama-index-embeddings-fastembed llama-index-llms-openai\n",
    "# !pip install -U qdrant_client fastembed\n",
    "# !pip install python-dotenv\n",
    "# !pip install ragas\n",
    "# !pip install trulens_eval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "13e5bfc1045040728f8cc7b53f448720",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import logging\n",
    "import sys\n",
    "import os\n",
    "from dotenv import load_dotenv\n",
    "from IPython.display import Markdown, display\n",
    "\n",
    "# qdrant official client\n",
    "import qdrant_client\n",
    "\n",
    "# LLama-index dependencies\n",
    "from llama_index.core import VectorStoreIndex, SimpleDirectoryReader\n",
    "from llama_index.core import SimpleDirectoryReader\n",
    "from llama_index.vector_stores.qdrant import QdrantVectorStore\n",
    "from llama_index.embeddings.fastembed import FastEmbedEmbedding\n",
    "from llama_index.core import Settings\n",
    "\n",
    "# setting the embedding model to BAAI/bge-base-en-v1.5 and FastEmbed to inference these models\n",
    "# Settings.embed_model = FastEmbedEmbedding(model_name=\"BAAI/bge-base-en-v1.5\")\n",
    "Settings.embed_model = FastEmbedEmbedding(model_name=\"BAAI/bge-base-en-v1.5\")\n",
    "# embed_model = FastEmbedEmbedding(model_name=\"BAAI/bge-base-en-v1.5\" , max_length=1024)\n",
    "\n",
    "# load all environment variables\n",
    "load_dotenv()\n",
    "OPENAI_API_KEY = os.getenv(\"OPENAI_API_KEY\")\n",
    "QDRANT_CLOUD_ENDPOINT = os.getenv(\"QDRANT_CLOUD_ENDPOINT\")\n",
    "QDRANT_API_KEY = os.getenv(\"QDRANT_API_KEY\")\n",
    "\n",
    "# Optional\n",
    "EVAL_DB_URL = os.getenv(\"EVAL_DB_URL\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading files: 100%|██████████| 1/1 [00:00<00:00, 221.35file/s]\n",
      "Loading files: 100%|██████████| 1/1 [00:00<00:00, 221.35file/s]\n"
     ]
    }
   ],
   "source": [
    "# lets loading the documents using SimpleDirectoryReader\n",
    "from llama_index.core import Document\n",
    "reader = SimpleDirectoryReader(\"./data/69_markdown_test/\" , recursive=True)\n",
    "documents = reader.load_data(show_progress=True)\n",
    "\n",
    "# combining all the documents into a single document for later chunking and splitting\n",
    "documents = Document(text=\"\\n\\n\".join([doc.text for doc in documents]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting up Vector Database\n",
    "\n",
    "We will be using qDrant as the Vector database\n",
    "There are 4 ways to initialize qdrant \n",
    "\n",
    "1. Inmemory\n",
    "```python\n",
    "client = qdrant_client.QdrantClient(location=\":memory:\")\n",
    "```\n",
    "2. Disk\n",
    "```python\n",
    "client = qdrant_client.QdrantClient(path=\"./data\")\n",
    "```\n",
    "3. Self hosted or Docker\n",
    "```python\n",
    "\n",
    "client = qdrant_client.QdrantClient(\n",
    "    # url=\"http://<host>:<port>\"\n",
    "    host=\"localhost\",port=6333\n",
    ")\n",
    "```\n",
    "\n",
    "4. Qdrant cloud\n",
    "```python\n",
    "client = qdrant_client.QdrantClient(\n",
    "    url=QDRANT_CLOUD_ENDPOINT,\n",
    "    api_key=QDRANT_API_KEY,\n",
    ")\n",
    "```\n",
    "\n",
    "for this notebook we will be using qdrant cloud"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# creating a qdrant client instance\n",
    "\n",
    "client = qdrant_client.QdrantClient(\n",
    "    # you can use :memory: mode for fast and light-weight experiments,\n",
    "    # it does not require to have Qdrant deployed anywhere\n",
    "    # but requires qdrant-client >= 1.1.1\n",
    "    # location=\":memory:\"\n",
    "    # otherwise set Qdrant instance address with:\n",
    "    url=QDRANT_CLOUD_ENDPOINT,\n",
    "    # otherwise set Qdrant instance with host and port:\n",
    "    # host=\"localhost\",\n",
    "    # port=6333\n",
    "    # set API KEY for Qdrant Cloud\n",
    "    api_key=QDRANT_API_KEY,\n",
    "    # path=\"./db/\"\n",
    ")\n",
    "\n",
    "vector_store = QdrantVectorStore(client=client, collection_name=\"4_Auto_merging_retrieval_RAG\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.node_parser import HierarchicalNodeParser\n",
    "from llama_index.core.node_parser import get_leaf_nodes, get_root_nodes\n",
    "\n",
    "\n",
    "chunk_sizes = [2048, 512, 128]\n",
    "node_parser = HierarchicalNodeParser.from_defaults(chunk_sizes=chunk_sizes)\n",
    "nodes = node_parser.get_nodes_from_documents([documents])\n",
    "leaf_nodes = get_leaf_nodes(nodes)"
   ]
  },
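  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check on the hierarchy (a minimal sketch using `get_root_nodes`, which we imported above): only the leaf nodes get embedded; the coarser levels live in the docstore so the retriever can merge back up to them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# inspect the three levels produced by chunk_sizes = [2048, 512, 128]\n",
    "root_nodes = get_root_nodes(nodes)\n",
    "print(f\"total nodes across all levels: {len(nodes)}\")\n",
    "print(f\"leaf nodes (finest, 128-token chunks): {len(leaf_nodes)}\")\n",
    "print(f\"root nodes (coarsest, 2048-token chunks): {len(root_nodes)}\")"
   ]
  },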
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define storage context\n",
    "from llama_index.core.storage.docstore import SimpleDocumentStore\n",
    "from llama_index.core import StorageContext\n",
    "docstore = SimpleDocumentStore()\n",
    "\n",
    "# insert nodes into docstore\n",
    "docstore.add_documents(nodes)\n",
    "\n",
    "# define storage context (will include vector store by default too)\n",
    "storage_context = StorageContext.from_defaults(docstore=docstore)\n",
    "\n",
    "\n",
    "base_index = VectorStoreIndex(\n",
    "    leaf_nodes,\n",
    "    storage_context=storage_context,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.retrievers import AutoMergingRetriever"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_retriever = base_index.as_retriever(similarity_top_k=6)\n",
    "retriever = AutoMergingRetriever(base_retriever, storage_context, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# query_str = \"What were some lessons learned from red-teaming?\"\n",
    "# query_str = \"Can you tell me about the key concepts for safety finetuning\"\n",
    "query_str = (\n",
    "    \"DALVANCE- than comparator-treated subjects with normal baseline transaminase\"\n",
    ")\n",
    "\n",
    "nodes = retriever.retrieve(query_str)\n",
    "base_nodes = base_retriever.retrieve(query_str)"
   ]
  },
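  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal comparison sketch: when enough sibling leaf nodes under one parent are retrieved, the auto-merging retriever should return fewer, larger chunks than the baseline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# compare the merged retrieval against the baseline\n",
    "print(f\"base retriever returned {len(base_nodes)} nodes\")\n",
    "print(f\"auto-merging retriever returned {len(nodes)} nodes\")\n",
    "\n",
    "# merged nodes should generally be longer than raw leaf chunks\n",
    "for n in nodes:\n",
    "    print(len(n.node.get_content()), \"chars |\", n.node.get_content()[:80].replace(\"\\n\", \" \"))"
   ]
  },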
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Modify Prompts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "qa_prompt_str = (\n",
    "    \"Context information is below.\\n\"\n",
    "    \"---------------------\\n\"\n",
    "    \"{context_str}\\n\"\n",
    "    \"---------------------\\n\"\n",
    "    \"Given the context information and not prior knowledge, \"\n",
    "    \"answer the question: {query_str}\\n\"\n",
    ")\n",
    "\n",
    "refine_prompt_str = (\n",
    "    \"We have the opportunity to refine the original answer \"\n",
    "    \"(only if needed) with some more context below.\\n\"\n",
    "    \"------------\\n\"\n",
    "    \"{context_msg}\\n\"\n",
    "    \"------------\\n\"\n",
    "    \"Given the new context, refine the original answer to better \"\n",
    "    \"answer the question: {query_str}. \"\n",
    "    \"If the context isn't useful, output the original answer again.\\n\"\n",
    "    \"Original Answer: {existing_answer}\"\n",
    ")\n",
    "\n",
    "from llama_index.core import ChatPromptTemplate\n",
    "\n",
    "# Text QA Prompt\n",
    "chat_text_qa_msgs = [\n",
    "    (\"system\",\"You are a AI assistant who is well versed with medical information and only answer question per training to the medical domain\"),\n",
    "    (\"user\", qa_prompt_str),\n",
    "]\n",
    "text_qa_template = ChatPromptTemplate.from_messages(chat_text_qa_msgs)\n",
    "\n",
    "# Refine Prompt\n",
    "chat_refine_msgs = [\n",
    "    (\"system\",\"Always answer the question, even if the context isn't helpful.\",),\n",
    "    (\"user\", refine_prompt_str),\n",
    "]\n",
    "refine_template = ChatPromptTemplate.from_messages(chat_refine_msgs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Final RAG application"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.query_engine import RetrieverQueryEngine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "query_engine = RetrieverQueryEngine.from_args(retriever)\n",
    "base_query_engine = RetrieverQueryEngine.from_args(base_retriever)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "Bronchospasm, a condition characterized by the constriction of the muscles in the airways, can be a symptom of asthma."
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "response = query_engine.query(\"How does one get asthma\")\n",
    "display(Markdown(str(response)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Performing Evaluation using RAGAS"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating Synthetic Test Set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ragas.testset.generator import TestsetGenerator\n",
    "from ragas.testset.evolutions import simple, reasoning, multi_context\n",
    "from langchain_openai import ChatOpenAI, OpenAIEmbeddings\n",
    "\n",
    "# documents = load your documents\n",
    "\n",
    "# generator with openai models\n",
    "generator_llm = ChatOpenAI(model=\"gpt-3.5-turbo-16k\")\n",
    "critic_llm = ChatOpenAI(model=\"gpt-4\")\n",
    "embeddings = OpenAIEmbeddings()\n",
    "\n",
    "generator = TestsetGenerator.from_langchain(\n",
    "    generator_llm,\n",
    "    critic_llm,\n",
    "    embeddings\n",
    ")\n",
    "\n",
    "# Change resulting question type distribution\n",
    "distributions = {\n",
    "    simple: 0.5,\n",
    "    multi_context: 0.4,\n",
    "    reasoning: 0.1\n",
    "}\n",
    "\n",
    "# use generator.generate_with_llamaindex_docs if you use llama-index as document loader\n",
    "\n",
    "# the document passes here is from the 2nd cell\n",
    "testset = generator.generate_with_llamaindex_docs([documents], 10, distributions) \n",
    "testset = testset.to_pandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# save it to use with other RAG techniques as well\n",
    "# testset.to_csv('eval_data.csv', index=False)\n",
    "import pandas as pd\n",
    "testset = pd.read_csv('./eval_data.csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run Evaluation using Truelens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from trulens_eval import Tru\n",
    "\n",
    "\n",
    "# Eval DB URL will be a postgres url\n",
    "if EVAL_DB_URL:\n",
    "    print(\"Connecting to postgres database\")\n",
    "    tru = Tru(database_url=EVAL_DB_URL)\n",
    "else:\n",
    "    print(\"Connecting to local sqlite database\")\n",
    "    tru = Tru()\n",
    "    \n",
    "# # incase you want to resent the database    \n",
    "# tru.reset_database()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✅ In groundedness_measure_with_cot_reasons, input source will be set to __record__.app.query.rets.source_nodes[:].node.text.collect() .\n",
      "✅ In groundedness_measure_with_cot_reasons, input statement will be set to __record__.main_output or `Select.RecordOutput` .\n",
      "✅ In relevance, input prompt will be set to __record__.main_input or `Select.RecordInput` .\n",
      "✅ In relevance, input response will be set to __record__.main_output or `Select.RecordOutput` .\n",
      "✅ In context_relevance_with_cot_reasons, input question will be set to __record__.main_input or `Select.RecordInput` .\n",
      "✅ In context_relevance_with_cot_reasons, input context will be set to __record__.app.query.rets.source_nodes[:].node.text .\n"
     ]
    }
   ],
   "source": [
    "from trulens_eval.feedback.provider import OpenAI\n",
    "from trulens_eval import Feedback\n",
    "import numpy as np\n",
    "\n",
    "# Initialize provider class\n",
    "provider = OpenAI()\n",
    "\n",
    "# select context to be used in feedback. the location of context is app specific.\n",
    "from trulens_eval.app import App\n",
    "context = App.select_context(query_engine)\n",
    "\n",
    "# Define a groundedness feedback function\n",
    "f_groundedness = (\n",
    "    Feedback(provider.groundedness_measure_with_cot_reasons)\n",
    "    .on(context.collect()) # collect context chunks into a list\n",
    "    .on_output()\n",
    ")\n",
    "\n",
    "# Question/answer relevance between overall question and answer.\n",
    "f_answer_relevance = (\n",
    "    Feedback(provider.relevance)\n",
    "    .on_input_output()\n",
    ")\n",
    "# Question/statement relevance between question and each context chunk.\n",
    "f_context_relevance = (\n",
    "    Feedback(provider.context_relevance_with_cot_reasons)\n",
    "    .on_input()\n",
    "    .on(context)\n",
    "    .aggregate(np.mean)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "from trulens_eval import TruLlama\n",
    "\n",
    "tru_query_engine_recorder = TruLlama(query_engine,app_id=\"4_Auto_merging_retrieval_RAG\",feedbacks=[f_groundedness, f_answer_relevance, f_context_relevance])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_questions = testset['question'].to_list()\n",
    "\n",
    "with tru_query_engine_recorder as recording:\n",
    "    for question in eval_questions:\n",
    "        response = query_engine.query(question)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "records, feedback = tru.get_records_and_feedback(app_ids=[])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "records.head()"
   ]
  },
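  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Beyond the raw records, TruLens can aggregate the feedback scores per app. A minimal sketch using `tru.get_leaderboard` (an empty `app_ids` list selects all apps):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# mean feedback scores, latency, and cost aggregated per app_id\n",
    "tru.get_leaderboard(app_ids=[])"
   ]
  },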
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting dashboard ...\n",
      "Config file already exists. Skipping writing process.\n",
      "Credentials file already exists. Skipping writing process.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3ebd552132694338b4d720e56b10a35c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Accordion(children=(VBox(children=(VBox(children=(Label(value='STDOUT'), Output())), VBox(children=(Label(valu…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dashboard started at http://192.168.1.5:8501 .\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<Popen: returncode: None args: ['streamlit', 'run', '--server.headless=True'...>"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tru.run_dashboard()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "tru.stop_dashboard()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm-venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
